package cn.zhangfusheng.elasticsearch.template;

import cn.zhangfusheng.elasticsearch.annotation.dsl.DslIndex;
import cn.zhangfusheng.elasticsearch.constant.ElasticSearchConstant;
import cn.zhangfusheng.elasticsearch.exception.GlobalSystemException;
import cn.zhangfusheng.elasticsearch.model.page.PageRequest;
import cn.zhangfusheng.elasticsearch.scan.ElasticSearchEntityRepositoryDetail;
import cn.zhangfusheng.elasticsearch.thread.ThreadLocalDetail;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
import org.elasticsearch.action.search.ClearScrollRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchScrollRequest;
import org.elasticsearch.client.core.CountRequest;
import org.elasticsearch.client.core.CountResponse;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.sort.SortOrder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;

import java.io.IOException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;

/**
 * @author fusheng.zhang
 * @date 2022-04-26 10:56:23
 */
public interface ElasticSearchTemplateApi extends Template {

    Logger log = LoggerFactory.getLogger(ElasticSearchTemplateApi.class);

    /**
     * Runs {@code searchRequest} and converts the collected hits through {@code entityRepositoryDetail}.
     * <p>
     * With a non-null {@code pageRequest} a single page is fetched via
     * {@link #searchWithPage(SearchRequest, PageRequest, Consumer)}; otherwise every matching hit is
     * collected via {@link #search(SearchRequest, Consumer)}.
     *
     * @param entityRepositoryDetail repository metadata that builds the returned object
     * @param method                 the repository method being executed
     * @param searchRequest          the request to run (its source builder is mutated)
     * @param pageRequest            paging parameters, or {@code null} to fetch all matches
     * @return whatever {@code analysisSearchResponse} produces for the collected hits
     * @throws IOException if the non-paged search fails at the transport level
     */
    default Object search(
            ElasticSearchEntityRepositoryDetail entityRepositoryDetail, Method method,
            SearchRequest searchRequest, PageRequest pageRequest) throws IOException {
        List<SearchHit> searchHits = new ArrayList<>();
        Consumer<SearchHit[]> collector = hits -> searchHits.addAll(Arrays.asList(hits));
        if (Objects.nonNull(pageRequest)) {
            SearchResponse searchResponse = this.searchWithPage(searchRequest, pageRequest, collector);
            return entityRepositoryDetail.analysisSearchResponse(searchResponse, searchHits, method, pageRequest.getSkipTotal());
        }
        SearchResponse searchResponse = this.search(searchRequest, collector);
        return entityRepositoryDetail.analysisSearchResponse(searchResponse, searchHits, method, 0);
    }

    /**
     * Executes {@code searchRequest}, passing every non-empty batch of hits to {@code consumer}.
     * <p>
     * When total-hit tracking is enabled for the current thread, the matching documents are counted
     * first. If the count exceeds {@code MAX_DOC_SIZE}, a scroll search is used instead. Otherwise
     * up to {@code totalNum} documents are fetched, either in one request or, for large totals, in
     * {@code from}/{@code size} batches of {@code THRESHOLD_DOC_SIZE}.
     *
     * @param searchRequest the request to run (its source builder's size/from are overwritten)
     * @param consumer      receives each non-empty array of hits
     * @return the last response received, or {@code null} if the batched loop made no request
     * @throws IOException if a search call fails at the transport level
     */
    default SearchResponse search(SearchRequest searchRequest, Consumer<SearchHit[]> consumer) throws IOException {
        long totalNum = ElasticSearchConstant.MAX_DOC_SIZE.longValue();
        // Fetch everything: first count how many documents match.
        if (ThreadLocalDetail.trackTotalHits()) {
            CountRequest countRequest =
                    new CountRequest(searchRequest.indices(), searchRequest.source()).routing(searchRequest.routing());
            totalNum = this.count(countRequest);
            // More matches than from/size paging can reach: switch to a scroll search.
            if (totalNum > ElasticSearchConstant.MAX_DOC_SIZE) {
                Integer scrollSize = ThreadLocalDetail.scrollSize().orElse(ElasticSearchConstant.THRESHOLD_DOC_SIZE);
                searchRequest.source().size(scrollSize);
                return this.searchWitchScroll(searchRequest, consumer, -1, null);
            }
        }
        int pageSize = ElasticSearchConstant.THRESHOLD_DOC_SIZE;
        if (totalNum <= pageSize * 2L) {
            // Small result: a single request sized to the expected total.
            searchRequest.source().size((int) totalNum);
            log.debug("GET {}/_search?routing={} {}",
                    String.join(",", searchRequest.indices()), searchRequest.routing(), searchRequest.source());
            SearchResponse searchResponse = restHighLevelClient().search(searchRequest, ThreadLocalDetail.requestOptions());
            SearchHit[] hits = searchResponse.getHits().getHits();
            if (Objects.nonNull(hits) && hits.length > 0) consumer.accept(hits);
            return searchResponse;
        }
        // Large result: fetch in batches. ceil(totalNum / pageSize) pages cover exactly totalNum documents,
        // including the case where totalNum is an exact multiple of pageSize.
        searchRequest.source().size(pageSize);
        long pageCount = (totalNum + pageSize - 1) / pageSize;
        SearchResponse searchResponse = null;
        for (long page = 0; page < pageCount; page++) {
            searchRequest.source().from((int) (page * pageSize));
            log.debug("GET {}/_search?routing={} {}",
                    String.join(",", searchRequest.indices()), searchRequest.routing(), searchRequest.source());
            searchResponse = restHighLevelClient().search(searchRequest, ThreadLocalDetail.requestOptions());
            SearchHit[] hits = searchResponse.getHits().getHits();
            if (Objects.isNull(hits) || hits.length == 0) break;
            consumer.accept(hits);
            // A short page means there is no next page.
            if (hits.length < pageSize) break;
        }
        return searchResponse;
    }

    /**
     * Fetches a single page described by {@code pageRequest}.
     * <p>
     * If {@code from + size} stays within {@code MAX_DOC_SIZE}, plain {@code from}/{@code size} paging is
     * used. Beyond that limit, the request must carry either {@code searchAfter} values, which are
     * applied as {@code search_after}, or {@code skipPage} ranges. {@code skipPage} ranges narrow the
     * query so that {@code from} can be rebased by {@code skipTotal} ("pseudo page jump"). When no sort
     * is configured, an ascending sort on {@code ElasticSearchConstant.SORT_ID} is added.
     *
     * @param searchRequest the request to run (its source builder is mutated)
     * @param pageRequest   paging parameters
     * @param consumer      receives the hits of the fetched page
     * @return the search response
     * @throws GlobalSystemException if the page is beyond the reachable window, or on transport failure
     */
    default SearchResponse searchWithPage(SearchRequest searchRequest, PageRequest pageRequest, Consumer<SearchHit[]> consumer) {
        try {
            int size = pageRequest.getSize();
            int from = pageRequest.getFrom();
            SearchSourceBuilder searchSourceBuilder = searchRequest.source();
            searchSourceBuilder.size(size);
            if (from + size <= ElasticSearchConstant.MAX_DOC_SIZE) {
                // Within the result window: plain from/size paging.
                searchSourceBuilder.from(from);
            } else if (ObjectUtils.isNotEmpty(pageRequest.getSearchAfter())) {
                // Beyond the window: continue from the caller-supplied sort values.
                searchSourceBuilder.searchAfter(pageRequest.getSearchAfter());
            } else if (ObjectUtils.isNotEmpty(pageRequest.getSkipPage())) {
                // Beyond the window: pseudo page jump. The skipPage ranges exclude skipTotal documents,
                // so the offset is rebased. A jump to the left is also expressed as a jump to the right.
                int skipTotal = pageRequest.getSkipTotal();
                int relativeFrom = from - skipTotal;
                // Elasticsearch rejects a request when from + size exceeds the window, not just from.
                if (relativeFrom + size > ElasticSearchConstant.MAX_DOC_SIZE) {
                    // Furthest document index reachable with this skip.
                    int dbMaxIndex = skipTotal + ElasticSearchConstant.MAX_DOC_SIZE;
                    // NOTE(review): assumes 1-based page numbers (from = (page - 1) * size); confirm against PageRequest.
                    int maxPageNumber = dbMaxIndex / size;
                    throw new GlobalSystemException("超过最大查询上限,最多跳转到{}页", maxPageNumber);
                }
                BoolQueryBuilder skipPageQueryBuilder = QueryBuilders.boolQuery();
                pageRequest.getSkipPage().forEach(skipPage -> skipPageQueryBuilder.must(skipPage.rangeQuery()));
                BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();
                // The request may have no query at all; BoolQueryBuilder.must rejects null clauses.
                if (Objects.nonNull(searchSourceBuilder.query())) boolQueryBuilder.must(searchSourceBuilder.query());
                boolQueryBuilder.must(skipPageQueryBuilder);
                searchSourceBuilder.query(boolQueryBuilder);
                searchSourceBuilder.from(relativeFrom);
            } else {
                throw new GlobalSystemException("超出最大查询数据限制,必须设置searchAfter 或者 skipPage参数");
            }
            // Ensure a deterministic order when the caller configured no sort.
            if (CollectionUtils.isEmpty(searchSourceBuilder.sorts())) {
                searchSourceBuilder.sort(ElasticSearchConstant.SORT_ID, SortOrder.ASC);
            }
            // Ask for an exact total hit count when tracking is enabled for this thread.
            if (ThreadLocalDetail.trackTotalHits()) searchSourceBuilder.trackTotalHits(true);
            log.debug("GET {}/_search?routing={} {}",
                    String.join(",", searchRequest.indices()), searchRequest.routing(), searchSourceBuilder);
            SearchResponse searchResponse = restHighLevelClient().search(searchRequest, ThreadLocalDetail.requestOptions());
            consumer.accept(searchResponse.getHits().getHits());
            return searchResponse;
        } catch (IOException e) {
            throw new GlobalSystemException(e);
        }
    }

    /**
     * Scroll search: passes each non-empty batch of hits to {@code consumer}.
     * <p>
     * With a blank {@code scrollId}, a new scroll is opened from {@code searchRequest}; otherwise the
     * given scroll is continued. The scroll context of the last response is always cleared before
     * returning.
     *
     * @param searchRequest the initial request (also supplies the routing used for logging)
     * @param consumer      receives each non-empty array of hits
     * @param runNum        maximum number of batches to consume; a value {@code < 0} means unlimited
     * @param scrollId      scroll id to continue from, or blank to start a new scroll
     * @return the last response received
     * @throws GlobalSystemException on any failure
     */
    default SearchResponse searchWitchScroll(
            SearchRequest searchRequest, Consumer<SearchHit[]> consumer, int runNum, String scrollId) {
        SearchResponse searchResponse = null;
        try {
            // How long Elasticsearch keeps the scroll context alive between calls.
            Long keepAlive = ThreadLocalDetail.keepAlive().orElse(5000L);
            TimeValue timeValue = TimeValue.timeValueMillis(keepAlive);
            // Set the keep-alive up front so the first continuation call also extends the context.
            SearchScrollRequest scrollRequest = new SearchScrollRequest(scrollId).scroll(timeValue);
            if (StringUtils.isBlank(scrollId)) {
                searchRequest.scroll(timeValue);
                log.debug("GET {}/_search?scroll={}&routing={} {}",
                        String.join(",", searchRequest.indices()), timeValue, searchRequest.routing(), searchRequest.source());
                searchResponse = restHighLevelClient().search(searchRequest, ThreadLocalDetail.requestOptions());
            } else {
                log.debug("POST _search/scroll {\"scroll\":\"{}\",\"scroll_id\":\"{}\",\"routing\":\"{}\"}",
                        timeValue, scrollRequest.scrollId(), searchRequest.routing());
                searchResponse = restHighLevelClient().scroll(scrollRequest, ThreadLocalDetail.requestOptions());
            }
            SearchHit[] hits = searchResponse.getHits().getHits();
            int num = 1;
            while (hits != null && hits.length > 0) {
                consumer.accept(hits);
                if (Objects.isNull(searchResponse.getScrollId()) || num++ == runNum) break;
                scrollRequest.scrollId(searchResponse.getScrollId());
                searchResponse = restHighLevelClient().scroll(scrollRequest, ThreadLocalDetail.requestOptions());
                hits = searchResponse.getHits().getHits();
            }
        } catch (GlobalSystemException e) {
            // Already our exception type (e.g. raised by the consumer): do not wrap it again.
            throw e;
        } catch (Exception e) {
            throw new GlobalSystemException(e);
        } finally {
            if (Objects.nonNull(searchResponse)) {
                this.clearScroll(searchResponse.getScrollId(), searchRequest.routing());
            }
        }
        return searchResponse;
    }

    /**
     * Best-effort release of a scroll context. A blank id is ignored. Failures are logged rather than
     * thrown, because this runs from a {@code finally} block and must not mask the caller's outcome.
     *
     * @param scrollId the scroll id to clear
     * @param routing  routing value, used only for logging
     */
    default void clearScroll(String scrollId, String routing) {
        if (StringUtils.isBlank(scrollId)) return;
        ClearScrollRequest clearScrollRequest = new ClearScrollRequest();
        clearScrollRequest.addScrollId(scrollId);
        try {
            restHighLevelClient().clearScroll(clearScrollRequest, ThreadLocalDetail.requestOptions());
            log.debug("DELETE _search/scroll {\"scroll_id\":\"{}\",\"routing\":\"{}\"}", scrollId, routing);
        } catch (Exception e) {
            log.error(e.getMessage(), e);
        }
    }

    /**
     * Counts the documents matching {@code countRequest}.
     *
     * @param countRequest the count request
     * @return the number of matching documents
     * @throws GlobalSystemException on transport failure
     */
    default long count(CountRequest countRequest) {
        try {
            log.debug("GET {}/_count?routing={} {}",
                    String.join(",", countRequest.indices()), countRequest.routing(), countRequest.source());
            CountResponse countResponse = restHighLevelClient().count(countRequest, ThreadLocalDetail.requestOptions());
            return countResponse.getCount();
        } catch (IOException e) {
            throw new GlobalSystemException(e);
        }
    }

    /**
     * Resolves the indices for a repository method and caches the result per method.
     * <p>
     * The indices come from the method's {@link DslIndex} annotation, falling back to {@code index}
     * when the annotation is absent or yields nothing. Duplicates are removed.
     *
     * @param method the repository method (may be {@code null})
     * @param args   the invocation arguments; not used for resolution, kept for API compatibility
     * @param index  the default index
     * @return the resolved index names
     */
    default String[] analysisIndex(Method method, Object[] args, String index) {
        if (ElasticSearchConstant.METHOD_INDEX_CACHE.containsKey(method)) {
            return ElasticSearchConstant.METHOD_INDEX_CACHE.get(method);
        }
        // NOTE(review): the unsynchronized read above is only safe if METHOD_INDEX_CACHE is a concurrent map; confirm.
        synchronized (ElasticSearchConstant.METHOD_INDEX_CACHE) {
            if (!ElasticSearchConstant.METHOD_INDEX_CACHE.containsKey(method)) {
                List<String> indices = new ArrayList<>();
                DslIndex dslIndex = Objects.isNull(method) ? null : method.getAnnotation(DslIndex.class);
                if (Objects.nonNull(dslIndex)) this.analysisIndex(dslIndex, indices);
                if (indices.isEmpty()) indices.add(index);
                ElasticSearchConstant.METHOD_INDEX_CACHE.put(method, indices.stream().distinct().toArray(String[]::new));
            }
            return ElasticSearchConstant.METHOD_INDEX_CACHE.get(method);
        }
    }

    /**
     * Appends the search index of every entity class listed in {@code dslIndex} to {@code indices}.
     * {@code Void.class} entries are skipped.
     *
     * @param dslIndex the annotation to read
     * @param indices  the list to append to
     * @throws GlobalSystemException if a listed class has no registered repository detail
     */
    default void analysisIndex(DslIndex dslIndex, List<String> indices) {
        for (Class<?> indexClass : dslIndex.value()) {
            if (indexClass.equals(Void.class)) continue;
            if (!ElasticSearchConstant.REPOSITORY_DETAIL_CACHE.containsKey(indexClass)) {
                throw new GlobalSystemException("根据{},未获取到对应的索引", indexClass.getName());
            }
            indices.add(ElasticSearchConstant.REPOSITORY_DETAIL_CACHE.get(indexClass).getSearchIndex());
        }
    }

    /**
     * Resolves indices from the {@link DslIndex} annotations on the argument classes.
     * Null arguments are skipped.
     *
     * @param args  the invocation arguments (may be {@code null})
     * @param index the fallback index names
     * @return the distinct annotated indices, or {@code index} when none are found
     */
    default String[] analysisIndex(Object[] args, String... index) {
        if (Objects.isNull(args)) return index;
        List<String> indices = new ArrayList<>(args.length);
        for (Object arg : args) {
            if (Objects.isNull(arg)) continue;
            DslIndex dslIndex = arg.getClass().getAnnotation(DslIndex.class);
            if (Objects.nonNull(dslIndex)) this.analysisIndex(dslIndex, indices);
        }
        if (indices.isEmpty()) return index;
        return indices.stream().distinct().toArray(String[]::new);
    }
}
